{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": {
   "name": "sb3_highway_dqn.ipynb",
   "provenance": [],
   "collapsed_sections": []
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  },
  "kernelspec": {
   "name": "python3",
   "language": "python",
   "display_name": "Python 3"
  },
  "accelerator": "GPU",
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "source": [],
    "metadata": {
     "collapsed": false
    }
   }
  }
 },
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "5eeje4O8fviH",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Highway with SB3's DQN\n",
    "\n",
    "##  Warming up\n",
    "We start with a few useful installs and imports:"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "bzMSuJEOfviP",
    "pycharm": {
     "is_executing": false,
     "name": "#%%\n"
    }
   },
   "source": [
    "# Install every dependency before importing anything, so the imports below load\n",
    "# the versions that were just installed. %pip targets the running kernel's environment.\n",
    "# NOTE(review): versions are unpinned -- confirm that the resolved gym, highway-env and\n",
    "# stable-baselines3 releases are mutually compatible, or pin them here.\n",
    "%pip install highway-env\n",
    "%pip install stable-baselines3\n",
    "%pip install tensorboardx gym pyvirtualdisplay\n",
    "!apt-get install -y xvfb python-opengl ffmpeg\n",
    "# Clone the repository for the helpers in its scripts/ folder. stderr is discarded,\n",
    "# so a failed clone (e.g. the directory already exists on a re-run) prints nothing.\n",
    "!git clone https://github.com/eleurent/highway-env.git 2> /dev/null\n",
    "\n",
    "# Standard library\n",
    "import os\n",
    "import sys\n",
    "\n",
    "# Environment\n",
    "import gym\n",
    "import highway_env\n",
    "\n",
    "# Agent\n",
    "from stable_baselines3 import DQN\n",
    "\n",
    "# Visualization utils\n",
    "%load_ext tensorboard\n",
    "from tqdm.notebook import trange\n",
    "# git clones into the current working directory, so resolve scripts/ relative to it\n",
    "# instead of hard-coding Colab's /content prefix.\n",
    "sys.path.insert(0, os.path.abspath('highway-env/scripts'))\n",
    "from utils import record_videos, show_videos"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    },
    "id": "_wACJRDjqP-f"
   },
   "source": [
    "## Training\n",
    "Launch TensorBoard in the next cell to monitor training; it reads the logs written to the `highway_dqn` directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Open the TensorBoard UI inline, pointed at the highway_dqn log directory.\n",
    "%tensorboard --logdir highway_dqn"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% \n"
    }
   }
  },
  {
   "cell_type": "code",
   "metadata": {
    "pycharm": {
     "name": "#%% \n"
    },
    "id": "Y5TOvonYqP-g"
   },
   "source": [
    "# Build the DQN agent. The environment is given by its registered id string.\n",
    "model = DQN(\n",
    "    policy=\"MlpPolicy\",\n",
    "    env=\"highway-fast-v0\",\n",
    "    policy_kwargs={\"net_arch\": [256, 256]},\n",
    "    learning_rate=0.0005,\n",
    "    buffer_size=15_000,\n",
    "    learning_starts=200,\n",
    "    batch_size=32,\n",
    "    gamma=0.8,\n",
    "    train_freq=1,\n",
    "    gradient_steps=1,\n",
    "    target_update_interval=50,\n",
    "    exploration_fraction=0.7,\n",
    "    verbose=1,\n",
    "    tensorboard_log=\"highway_dqn/\",  # directory the %tensorboard cell points at\n",
    ")\n",
    "\n",
    "# Train for a total of 20,000 timesteps.\n",
    "model.learn(total_timesteps=20_000)\n"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "n2Bu_Pqop0E7"
   },
   "source": [
    "## Testing\n",
    "\n",
    "Run the trained agent for a few episodes and display the recorded videos."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "xOcOP7Of18T2",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "source": [
    "# Roll out the trained agent for three episodes while recording video.\n",
    "env = gym.make(\"highway-fast-v0\")\n",
    "try:\n",
    "    env = record_videos(env)\n",
    "    for episode in trange(3, desc=\"Test episodes\"):\n",
    "        obs, done = env.reset(), False\n",
    "        while not done:\n",
    "            # deterministic=True: ask the policy for its action without sampling\n",
    "            action, _ = model.predict(obs, deterministic=True)\n",
    "            obs, reward, done, info = env.step(action)\n",
    "finally:\n",
    "    # Always release the environment (and whatever record_videos wrapped around it),\n",
    "    # even when an episode raises or the run is interrupted.\n",
    "    env.close()\n",
    "show_videos()\n"
   ],
   "execution_count": null,
   "outputs": []
  }
 ]
}